/**
 * Copyright 2024-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FRAMEWORK_TENSOR_COMPATIBLE_AI_TENSOR_IMPL_H
#define FRAMEWORK_TENSOR_COMPATIBLE_AI_TENSOR_IMPL_H

#include <cstdint>
#include <memory>

#include "compatible/AiTensor.h"
#include "tensor/nd_tensor_buffer.h"
#include "tensor/nd_tensor_desc.h"


namespace hiai {
/**
 * @brief Implementation storage for a four-component tensor dimension:
 *        number, channel, height and width.
 *
 * Every component is an unsigned 32-bit value. The data members carry in-class
 * initializers of 0. The out-of-line constructors may still assign other values;
 * see the .cpp.
 */
class TensorDimensionImpl {
public:
    /// Default constructor (defined out of line).
    TensorDimensionImpl();

    /**
     * @brief Builds a dimension from its four components.
     * @param number  value for the number component
     * @param channel value for the channel component
     * @param height  value for the height component
     * @param width   value for the width component
     */
    TensorDimensionImpl(uint32_t number, uint32_t channel, uint32_t height, uint32_t width);

    /// Virtual so the object can be destroyed through a base pointer by derived types.
    virtual ~TensorDimensionImpl();

    // ---- mutators -------------------------------------------------------
    void SetNumber(uint32_t number);
    void SetChannel(uint32_t channel);
    void SetHeight(uint32_t height);
    void SetWidth(uint32_t width);

    // ---- accessors ------------------------------------------------------
    uint32_t GetNumber() const;
    uint32_t GetChannel() const;
    uint32_t GetHeight() const;
    uint32_t GetWidth() const;

    /**
     * @brief Compares this dimension with four explicit component values.
     *
     * NOTE(review): presumably returns true only when all four components match.
     * Confirm in the .cpp. This member is not declared const, so it cannot be
     * called on a const TensorDimensionImpl. Making it const needs a matching
     * change in the definition.
     */
    bool IsEqual(uint32_t number, uint32_t channel, uint32_t height, uint32_t width);

private:
    uint32_t n_ = 0; // number component (presumably backs Get/SetNumber)
    uint32_t c_ = 0; // channel component (presumably backs Get/SetChannel)
    uint32_t h_ = 0; // height component (presumably backs Get/SetHeight)
    uint32_t w_ = 0; // width component (presumably backs Get/SetWidth)
};

/**
 * @brief Implementation object behind the compatible AiTensor API.
 *
 * Pairs a shared handle to an INDTensorBuffer with an NDTensorDesc.
 * NOTE(review): how the TensorDimension passed to Init/SetTensorDimension maps
 * onto desc_ is decided in the .cpp. Confirm there before relying on it.
 */
class AiTensorImpl {
public:
    AiTensorImpl();
    virtual ~AiTensorImpl();

    /**
     * @name Initialisation overloads
     * Each overload returns an AIStatus; the specific codes are set in the .cpp.
     * NOTE(review): null handling for the pointer arguments is not visible here.
     */
    ///@{
    /// Initialises from a dimension only.
    AIStatus Init(const TensorDimension* dim);
    /// Initialises from a dimension plus an element data type.
    AIStatus Init(const TensorDimension* dim, HIAI_DataType pdataType);
    /// Initialises from caller-supplied data.
    /// NOTE(review): whether `data` is copied or referenced is decided in the .cpp.
    AIStatus Init(const void* data, const TensorDimension* dim, HIAI_DataType pdataType);
    /// Initialises on top of an existing NativeHandle.
    AIStatus Init(const NativeHandle& handle, const TensorDimension* dim, HIAI_DataType pdataType);
    /// Initialises an image tensor from number/height/width and an image format (no channel argument).
    AIStatus Init(uint32_t number, uint32_t height, uint32_t width, AiTensorImage_Format format);
    ///@}

    /// Raw buffer pointer. Virtual, so subclasses may override it.
    virtual void* GetBuffer() const;
    /// Buffer size. Virtual, so subclasses may override it.
    /// NOTE(review): presumably in bytes; confirm in the .cpp.
    virtual uint32_t GetSize() const;

    /// Replaces the tensor's dimension. Returns an AIStatus.
    AIStatus SetTensorDimension(const TensorDimension* dim);
    /// Current dimension, returned by value.
    TensorDimension GetTensorDimension() const;
    /// Tensor buffer pointer. NOTE(review): its relation to GetBuffer() is not visible here.
    void* GetTensorBuffer() const;

private:
    // Lets the model-manager client reach tensor_ and desc_ directly.
    friend class AiModelMngerClientImpl;

    std::shared_ptr<INDTensorBuffer> tensor_; // shared ownership of the underlying buffer
    NDTensorDesc desc_;                       // descriptor for the tensor
};

}
#endif